const std = @import("std");

const stdx = @import("stdx");

const aio = @import("aio.zig");
const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const callback = @import("callback.zig");
const CompilationContext = @import("module.zig").CompilationContext;
const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;

const log = std.log.scoped(.@"zml/exe");

// Reference every public declaration in this file so that nested `test` blocks
// (including those inside generic types) are compiled and run by `zig test`.
test {
    std.testing.refAllDecls(@This());
}

/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Convenience wrapper around `compileWithPrefix` with an empty prefix.
/// The steps are:
/// * look up the tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to set fields of the model that aren't Tensors, i.e. hyperparameters/config
/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments
pub fn compile(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
) !FnExe(func) {
    return compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, "");
}

/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Uses a prefix for looking up model weights in the buffer store.
/// The steps are:
/// * look up the tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to set fields of the model that aren't Tensors, i.e. hyperparameters/config
/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments
pub fn compileWithPrefix(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
    prefix: []const u8,
) !FnExe(func) {
    const ModelT = ModuleSignature(func).ModelT;

    // The populated model only needs to live until compilation is done,
    // so it is allocated in a short-lived arena.
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();
    var model = try aio.populateModelWithPrefix(ModelT, arena, buffer_store, prefix);

    // If the Model has an "init" function, call it with the given parameters.
    if (@hasDecl(ModelT, "init")) {
        @call(.auto, ModelT.init, .{@as(*ModelT, &model)} ++ init_args);
    } else if (@typeInfo(@TypeOf(init_args)) == .@"struct" and @typeInfo(@TypeOf(init_args)).@"struct".fields.len > 0) {
        // Catch silently-dropped configuration: the caller passed init args,
        // but the model has no `init` function to receive them.
        @compileError("compileWithPrefix: init_args were provided, but " ++ @typeName(ModelT) ++ " has no `init` function to receive them");
    }

    return compileModel(allocator, func, model, args_shapes, platform);
}

/// Compiles an already-constructed `model` value for the given platform.
/// MLIR is generated by tracing `model.forward` with tensors of the given shapes
/// plus the remaining arguments.
pub fn compileModel(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    model: ModuleSignature(func).ModelT,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    platform: Platform,
) !FnExe(func) {
    const module_name = @typeName(ModuleSignature(func).ModelT) ++ ".forward";
    log.info("Compiling {s} with {f}", .{ module_name, stdx.fmt.any(args_shapes) });

    var ctx = try CompilationContext.init(allocator, module_name, platform);
    defer ctx.deinit();

    // The model itself is passed as the first (implicit) argument of `forward`.
    const compiled = try ctx.compileInternal(allocator, func, .{model} ++ args_shapes);
    return .{ .inner = compiled };
}

/// Compiles a plain function (no model struct) for the given platform.
/// MLIR is generated by tracing `func` with tensors of the given shapes.
pub fn compileFn(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    args: ShapeOf(stdx.meta.FnArgs(func)),
    platform: Platform,
) !FnExe(func) {
    // Derive a compact human-readable name from the function's type name.
    var fn_name = try prettyFnName(func, allocator);
    defer fn_name.deinit(allocator);
    log.info("Compiling {s} with {f}", .{ fn_name.items, stdx.fmt.any(args) });

    var ctx = try CompilationContext.init(allocator, fn_name.items, platform);
    defer ctx.deinit();

    const compiled = try ctx.compileInternal(allocator, func, args);
    return .{ .inner = compiled };
}

/// The `Exe` type produced by compiling `func`: same argument tuple, same result type.
pub fn FnExe(comptime func: anytype) type {
    const ArgsT = stdx.meta.FnArgs(func);
    const ResultT = stdx.meta.FnResult(func);
    return Exe(ArgsT, ResultT);
}

/// Represents a ZML model, compiled into a PJRT executable, and ready to call.
/// The buffers for the model weights are saved inside the struct and will be used in `call`.
/// You only need to pass the remaining arguments.
/// Creating a `ModuleExe` is a two steps process:
///
/// ```
/// const exe: zml.FnExe(MyModel.forward) = try zml.compile(allocator, MyModel.forward, init_args, model_shapes, buffer_store, platform);`
/// const module: zml.ModuleExe(MyModel.forward) = exe.prepare(model_buffers);
/// ```
pub fn ModuleExe(comptime func: anytype) type {
    // `ModuleSignature` already performs the "at least one argument" comptime
    // check, so the assertion previously duplicated here is not repeated.
    const sign = ModuleSignature(func);
    return Exe(sign.ArgsT, sign.ReturnT);
}

// making this a struct forces all fields to be evaluated on creation,
// which gives a better error stacktrace
// than delaying the error to when the object fields are read.
const Sign = struct {
    // Type of the first parameter: the model/module struct itself.
    ModelT: type,
    // Tuple type of the remaining parameters (everything after the model).
    ArgsT: type,
    // Return type of the function.
    ReturnT: type,
};

/// Splits the signature of `func` into (model type, remaining-args tuple, return type).
/// The first parameter is treated as the module/model.
pub fn ModuleSignature(comptime func: anytype) Sign {
    const AllArgsT = stdx.meta.FnArgs(func);
    const len = @typeInfo(AllArgsT).@"struct".fields.len;
    // Message previously said "ModuleExe expects" (copy-paste from ModuleExe).
    stdx.debug.assertComptime(len > 0, "ModuleSignature expects a function with at least one argument where the first one is treated as the module, got {}", .{func});

    return .{
        .ModelT = stdx.meta.Head(AllArgsT),
        .ArgsT = stdx.meta.Tail(AllArgsT),
        .ReturnT = stdx.meta.FnResult(func),
    };
}

/// Represents an MLIR module compiled into a PJRT executable.
/// The BaseExe is a plain old struct and doesn't have information about Zig types.
///
/// It also contains pre-allocated buffers so that we can pass them to PJRT_LoadedExecutable_Execute
/// without allocations.
pub const BaseExe = struct {
    /// The platform for which this module was compiled.
    platform: Platform,

    /// The PJRT executable representing the compiled module.
    exe: *pjrt.LoadedExecutable,

    /// The execution context for this executable.
    /// Only non-null when the PJRT plugin exposes the FFI extension (see `init`).
    execute_context: ?*pjrt.ExecuteContext,

    /// Pre-allocated slice of buffers to use as inputs when the module is called.
    /// One entry per device; entry `d` points at `input_buffer_count` slots for device `d`.
    input_per_device: []const [*]*pjrt.Buffer,

    /// Pre-allocated slice of buffers to use as outputs when the module is called.
    /// One entry per device; entry `d` points at `result_shapes.len` slots for device `d`.
    output_per_device: []const [*]*pjrt.Buffer,

    /// Number of buffers already fed to the executable.
    /// Incremented by `prepare`; must reach `input_buffer_count` before `call`.
    ready_buffer_count: u32,

    /// Total number of buffers needed by this executable.
    input_buffer_count: u32,

    input_shapes: []Shape,
    result_shapes: []Shape,

    /// Num devices used (>1 for sharded executable)
    num_devices: u8,

    /// Allocator backing memory
    _arena: std.heap.ArenaAllocator,

    /// Allocates all per-call bookkeeping (input/output buffer slots, shapes) in a
    /// private arena and, if the plugin supports FFI, creates an execute context
    /// with the internal callbacks bound. Takes ownership of nothing but the arena;
    /// `exe` is released by `deinit`.
    pub fn init(
        parent_allocator: std.mem.Allocator,
        platform: Platform,
        exe: *pjrt.LoadedExecutable,
        args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 },
    ) !BaseExe {
        var arena = std.heap.ArenaAllocator.init(parent_allocator);
        errdefer arena.deinit();
        const allocator = arena.allocator();
        const n_in = args.input_shapes.len;
        const n_out = args.result_shapes.len;
        const n_devices = args.n_devices;
        // Allocate once for all the *pjrt.Buffer we need to store ...
        const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices);
        const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices });

        // ... and once for all the [*]*pjrt.Buffer.
        const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
        const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });

        // Carve per-device windows out of the flat allocations:
        // device i owns input slots [i*n_in, (i+1)*n_in) and output slots [i*n_out, (i+1)*n_out).
        for (0..n_devices) |i| {
            input_per_device[i] = all_input_buffers[i * n_in ..].ptr;
            output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
        }

        // Input and result shapes share one allocation; split by index below.
        const all_shapes = try allocator.alloc(Shape, n_in + n_out);
        @memcpy(all_shapes[0..n_in], args.input_shapes);
        @memcpy(all_shapes[n_in..], args.result_shapes);

        var execute_context: ?*pjrt.ExecuteContext = null;
        if (platform.pjrt_api.ffi()) |ffi| {
            execute_context = try platform.pjrt_api.createExecuteContext();
            try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
            // log.info("Created context execution {*} for {*}", .{ execute_context, exe });
        }

        return .{
            .platform = platform,
            .exe = exe,
            .execute_context = execute_context,
            .ready_buffer_count = 0,
            .input_buffer_count = @intCast(n_in),
            .num_devices = args.n_devices,
            .input_per_device = input_per_device,
            .output_per_device = output_per_device,
            .input_shapes = all_shapes[0..n_in],
            .result_shapes = all_shapes[n_in..],
            ._arena = arena,
        };
    }

    /// Releases the execute context (if any), the PJRT executable, and the arena
    /// holding all the pre-allocated bookkeeping.
    pub fn deinit(self: BaseExe) void {
        if (self.execute_context) |ctx| {
            ctx.deinit(self.platform.pjrt_api);
        }
        self.exe.deinit(self.platform.pjrt_api);
        self._arena.deinit();
    }

    /// Executes the module, asserting first that every input slot was filled
    /// (via `prepare`). Use the typed `Exe.call` wrapper in normal code.
    pub fn call(self: BaseExe) void {
        stdx.debug.assert(self.input_buffer_count == self.ready_buffer_count, "BaseExe isn't ready to be called, expected {} buffer inputs got {}", .{ self.input_buffer_count, self.ready_buffer_count });
        return self._unsafeCall();
    }

    /// Executes the module without checking that the input buffers are ready.
    /// Panics if PJRT reports an execution error.
    pub fn _unsafeCall(self: BaseExe) void {
        var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES;
        const sharding = self.platform.sharding();

        self.exe.execute(self.platform.pjrt_api, .{
            .arguments = self.input_per_device,
            .num_args = self.input_buffer_count,
            .results = self.output_per_device,
            .events = events[0..sharding.num_partitions],
            // this allows to tell a specific buffer shouldn't be donated,
            // even if it has been marked as "can be donated" during compilation.
            // TODO: expose it ?
            .non_donatable_input_indices = &.{},
            .context = self.execute_context,
        }) catch |err| {
            std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
        };

        // for (events[0..sharding.num_partitions]) |e| {
        //     if (e) |ev| {
        //         ev.await(self.platform.pjrt_api) catch unreachable;
        //     }
        // }
    }

    /// Walks `result` (via `meta.visit`) and assigns every `Buffer` field from the
    /// executable's output slots, gathering one shard per device for each output.
    /// Asserts that the number of Buffers in `T` matches `result_shapes.len`.
    pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
        const LocalContext = struct {
            index: u32,
            platform: Platform,
            outputs: []const [*]*pjrt.Buffer,
            output_shapes: []Shape,
        };
        var local_ctx: LocalContext = .{
            .index = 0,
            .platform = self.platform,
            .outputs = self.output_per_device,
            .output_shapes = self.result_shapes,
        };
        meta.visit((struct {
            fn cb(ctx: *LocalContext, buffer: *Buffer) void {
                const i = ctx.index;
                ctx.index += 1;
                // Out-of-range Buffers are skipped here; the mismatch is reported
                // by the assert after the visit instead.
                if (i >= ctx.output_shapes.len) return;

                // Gather output `i` across all devices into one sharded Buffer.
                var shards: Buffer.Shards = .{};
                for (ctx.outputs) |buff| {
                    shards.appendAssumeCapacity(buff[i]);
                }
                buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice());
            }
        }).cb, &local_ctx, result);
        // NOTE(review): the message formats `output_per_device.len` (device count)
        // as the "returned tensors" count — `result_shapes.len` looks intended.
        stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.output_per_device.len, @typeName(T), local_ctx.index });
    }

    /// For a custom call inside this executable, attaches a pointer to runtime
    /// user data via the FFI extension. Requires an execute context (see `init`).
    /// The caller keeps ownership of `op` and must keep it alive while the
    /// executable is in use.
    pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void {
        stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{});
        const pjrt_api = exe.platform.pjrt_api;

        if (pjrt_api.ffi()) |ffi| {
            try callback.addUserData(Callback, pjrt_api, ffi, exe.execute_context.?, op);
        } else {
            stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)});
        }
    }

    /// Writes the serialized form of the underlying PJRT executable to `writer`.
    pub fn serialize(self: BaseExe, writer: *std.Io.Writer) !void {
        var executable = try self.exe.getExecutable(self.platform.pjrt_api);
        var serialize_result = try executable.serialize(self.platform.pjrt_api);
        defer serialize_result.deinit();
        try writer.writeAll(serialize_result.bytes);
    }

    // pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self {
    //     const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size);
    //     defer allocator.free(bytes);
    //     return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes);
    // }

    /// Fills the next input slots with the Buffers found in `x` and advances
    /// `ready_buffer_count` accordingly. Used for partial application (`Exe.prepare`).
    pub fn prepare(self: *BaseExe, x: anytype) void {
        const n = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count);
        self.ready_buffer_count += n;
    }

    /// Returns output `i` as a sharded Buffer (one shard per device).
    /// Only meaningful after the executable has been called.
    pub fn getOutputBuffer(self: BaseExe, i: usize) Buffer {
        var shards: Buffer.Shards = .{};
        for (self.output_per_device) |dev_out| {
            shards.appendAssumeCapacity(dev_out[i]);
        }

        return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
    }

    /// Re-allocates the buffer bookkeeping in a fresh arena while sharing the
    /// underlying `exe` and `execute_context` with `self`.
    /// NOTE(review): `init` may create a new execute_context (FFI path) which is
    /// overwritten below without being freed, and the shared context/exe risk a
    /// double `deinit` if both copies are deinited — confirm intended ownership.
    pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
        var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
            .input_shapes = self.input_shapes,
            .result_shapes = self.result_shapes,
            .n_devices = self.num_devices,
        });
        exe.execute_context = self.execute_context;
        return exe;
    }
};

/// Represents a ZML function, compiled into a PJRT executable.
/// The signature of the Exe reflects the arguments that are needed for `call`.
pub fn Exe(ArgsT: type, ReturnT: type) type {
    return struct {
        /// The raw untyped compiled module.
        inner: BaseExe,

        const Self = @This();

        /// Frees the underlying executable and its pre-allocated buffers.
        pub fn deinit(self: Self) void {
            self.inner.deinit();
        }

        /// The platform this executable was compiled for.
        pub fn platform(self: Self) Platform {
            return self.inner.platform;
        }

        /// Hardcode the first argument of the function to the given buffers.
        /// Returns an Exe with one less argument in `call`.
        /// In functional languages this is known as partial application.
        ///
        /// **Warning:** the new Exe reuses the underlying memory of the previous one.
        /// The caller is responsible to come up with a strategy to call `deinit` exactly once.
        pub fn prepare(self: Self, first_arg: Bufferized(stdx.meta.Head(ArgsT))) Exe(stdx.meta.Tail(ArgsT), ReturnT) {
            var partial: Exe(stdx.meta.Tail(ArgsT), ReturnT) = .{ .inner = self.inner };
            partial.inner.prepare(first_arg);
            return partial;
        }

        /// For a given customCall inside this executable,
        /// provide a pointer to runtime data.
        /// The caller keeps memory ownership and need to ensure that the value
        /// stays alive as long as the executable.
        pub fn bind(self: Self, comptime T: type, value: *T) !void {
            try self.inner.bind(T, value);
        }

        /// Writes the serialized PJRT executable to `writer`.
        pub fn serialize(self: Self, writer: *std.Io.Writer) !void {
            return try self.inner.serialize(writer);
        }

        /// Feeds `args` into the remaining input slots, executes the module,
        /// and gathers the outputs into a `Bufferized(ReturnT)` value.
        pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
            const ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count);
            std.debug.assert(ready == self.inner.input_buffer_count);
            self.inner._unsafeCall();
            var out: Bufferized(ReturnT) = undefined;
            self.inner._unsafeAssignResults(Bufferized(ReturnT), &out);
            return out;
        }

        /// Duplicates the bookkeeping into `allocator`; shares the compiled executable.
        pub fn clone(self: Self, allocator: std.mem.Allocator) !Self {
            return .{ .inner = try self.inner.clone(allocator) };
        }
    };
}

/// Partitions `buffer` into `lengths.len` consecutive sub-slices whose sizes are
/// given by the comptime tuple `lengths`. Asserts the lengths exactly cover `buffer`.
fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
    var parts: [lengths.len][]T = undefined;
    var offset: usize = 0;
    inline for (lengths, 0..) |len, k| {
        parts[k] = buffer[offset .. offset + len];
        offset += len;
    }
    std.debug.assert(offset == buffer.len);
    return parts;
}

/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor.
///
/// `v` is traversed recursively via `meta.visit`; every `Buffer` found is written
/// into the per-device slot at the next free index, one shard per device.
/// `start` is the first free slot; returns one past the last filled slot, so
/// successive calls can be chained (see `BaseExe.prepare` / `Exe.call`).
fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
    const LocalContext = struct {
        index: u32,
        buffers: []const [*]*pjrt.Buffer,
        shapes: []const Shape,
    };
    var context: LocalContext = .{
        .index = start,
        .buffers = buffers,
        .shapes = shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
            // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
            // The argument must carry exactly one shard per device of the model ...
            const model_sharding = ctx.buffers.len;
            stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {d}-sharded tensor into a {d}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
            // ... and match the shape the executable was compiled for.
            stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {f}, got {f}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
            // Scatter the shards: shard `d` goes into device `d`'s argument list.
            for (buffer._shards.constSlice(), 0..) |shard, d| {
                ctx.buffers[d][ctx.index] = shard;
            }
            ctx.index += 1;
        }
    }).cb, &context, v);
    return context.index;
}

/// Returns a compact, human-readable name for `func`, derived from its type name,
/// with noisy fully-qualified tensor/shape type names shortened.
/// Caller owns the returned list and must `deinit` it with the same allocator.
fn prettyFnName(
    comptime func: anytype,
    allocator: std.mem.Allocator,
) !std.ArrayList(u8) {
    const full_noisy_name = @typeName(@TypeOf(func));
    const og_len = full_noisy_name.len;
    const buffer = try allocator.alloc(u8, og_len);
    errdefer comptime unreachable; // No errors below this point.
    var out: []u8 = buffer;

    // Strip the longer pattern first: once "tensor.Tensor" has been rewritten to
    // "Tensor", the "tensor.Tensor." prefix can no longer be found, which made
    // this pass dead code in the previous ordering.
    {
        const verbose = "tensor.Tensor.";
        const compact = "";
        const num_replacements = std.mem.replace(u8, full_noisy_name, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }

    {
        const verbose = "tensor.Tensor";
        const compact = "Tensor";
        // `out` aliases `buffer`; this is safe because the replacement is shorter
        // than the needle, so writes never overtake reads.
        const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }

    {
        const verbose = "shape.Shape";
        const compact = "Shape";
        const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
        out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
    }

    return .{ .items = out, .capacity = og_len };
}
